{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "\n",
    "\n",
    "def generate_real_data_distribution(n_dim, num_samples):\n",
    "    \"\"\"Build the synthetic 'real' dataset used to train the GAN.\n",
    "\n",
    "    Every entry is the log-normal PDF (shape parameter 0.6) evaluated at a\n",
    "    point drawn uniformly from [0, 8).\n",
    "\n",
    "    Args:\n",
    "        n_dim (int): number of values per sample (columns).\n",
    "        num_samples (int): number of samples (rows).\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: array of shape (num_samples, n_dim).\n",
    "    \"\"\"\n",
    "    # A single vectorised draw replaces the per-sample Python loop and keeps\n",
    "    # the result two-dimensional even when num_samples is 0.\n",
    "    x = np.random.uniform(0, 8, size=(num_samples, n_dim))\n",
    "    all_data = stats.lognorm.pdf(x, 0.6)\n",
    "    print('generated data shape: ', all_data.shape)\n",
    "    return all_data\n",
    "\n",
    "\n",
    "def batch_inputs(all_data, batch_size=6):\n",
    "    \"\"\"Sample a random mini-batch of rows and return it as a float tensor.\n",
    "\n",
    "    Row indices are drawn with np.random.randint, so the same row may appear\n",
    "    more than once within a batch.\n",
    "\n",
    "    Args:\n",
    "        all_data (np.ndarray): source array; rows are indexed along axis 0.\n",
    "        batch_size (int): number of rows to draw.\n",
    "\n",
    "    Returns:\n",
    "        The selected rows as a float32 torch tensor wrapped in Variable.\n",
    "    \"\"\"\n",
    "    assert isinstance(all_data, np.ndarray), 'all_data must be numpy array'\n",
    "    total_rows = all_data.shape[0]\n",
    "    chosen = np.random.randint(total_rows, size=batch_size)\n",
    "    as_tensor = torch.from_numpy(all_data[chosen])\n",
    "    return Variable(as_tensor.float())\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Train a small fully-connected GAN on synthetic log-normal PDF samples.\n",
    "\n",
    "    The generator maps Gaussian noise vectors to vectors as wide as the real\n",
    "    data; the discriminator outputs the probability that a vector is real.\n",
    "    Both networks are optimised with Adam, alternating one discriminator step\n",
    "    and one generator step per mini-batch, and both losses are printed for\n",
    "    every mini-batch.\n",
    "\n",
    "    Returns:\n",
    "        int: always 0.\n",
    "    \"\"\"\n",
    "    # Dimensionality of the noise vector fed to the generator.\n",
    "    n_noise_dim = 30\n",
    "    # Dimensionality of each real data sample.\n",
    "    n_real_data_dim = 256\n",
    "    num_samples = 666\n",
    "    lr_g = 0.001\n",
    "    lr_d = 0.03\n",
    "    batch_size = 6\n",
    "    epochs = 1000\n",
    "    # Keeps log() finite if the sigmoid saturates to exactly 0 or 1.\n",
    "    eps = 1e-8\n",
    "\n",
    "    real_data = generate_real_data_distribution(n_real_data_dim, num_samples=num_samples)\n",
    "    print('sample from real data: \\n', real_data[: 10])\n",
    "\n",
    "    g_net = nn.Sequential(\n",
    "        nn.Linear(n_noise_dim, 128),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(128, n_real_data_dim)\n",
    "    )\n",
    "\n",
    "    d_net = nn.Sequential(\n",
    "        nn.Linear(n_real_data_dim, 128),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(128, 1),\n",
    "        nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "    opt_d = torch.optim.Adam(d_net.parameters(), lr=lr_d)\n",
    "    opt_g = torch.optim.Adam(g_net.parameters(), lr=lr_g)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        for i in range(num_samples // batch_size):\n",
    "            batch_x = batch_inputs(real_data, batch_size)\n",
    "            batch_noise = torch.randn(batch_size, n_noise_dim)\n",
    "\n",
    "            g_data = g_net(batch_noise)\n",
    "\n",
    "            # ---- Discriminator step ----\n",
    "            # D estimates the probability that each input is real. The fake\n",
    "            # batch is detached so this update does not backpropagate into G.\n",
    "            prob_real = d_net(batch_x)\n",
    "            prob_fake = d_net(g_data.detach())\n",
    "\n",
    "            # The term inside mean() is a sum of log-probabilities and hence\n",
    "            # negative; it is negated to get a positive loss to minimise.\n",
    "            d_loss = -torch.mean(torch.log(prob_real + eps) + torch.log(1 - prob_fake + eps))\n",
    "\n",
    "            opt_d.zero_grad()\n",
    "            d_loss.backward()\n",
    "            opt_d.step()\n",
    "\n",
    "            # ---- Generator step ----\n",
    "            # Re-score the fakes with the updated D: opt_d.step() changed D's\n",
    "            # weights in place, so a graph built before it cannot be reused\n",
    "            # for backward() in recent PyTorch versions.\n",
    "            prob_fake = d_net(g_data)\n",
    "            # G wants D to rate its samples as real, i.e. to maximise\n",
    "            # log(prob_fake), hence the minus sign.\n",
    "            g_loss = -torch.mean(torch.log(prob_fake + eps))\n",
    "\n",
    "            opt_g.zero_grad()\n",
    "            # This also accumulates gradients in D's parameters; they are\n",
    "            # cleared by opt_d.zero_grad() before the next discriminator step.\n",
    "            g_loss.backward()\n",
    "            opt_g.step()\n",
    "\n",
    "            # The losses are 0-dim tensors, so read them with .item().\n",
    "            print('Epoch: {}, batch: {}, d_loss: {}, g_loss: {}'.format(epoch, i, d_loss.item(),\n",
    "                                                                        g_loss.item()))\n",
    "\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generated data shape:  (666, 256)\n",
      "sample from real data: \n",
      " [[7.55683044e-01 2.41670159e-01 1.89884750e-01 ... 3.37450888e-02\n",
      "  5.43958932e-03 3.02166014e-01]\n",
      " [5.75202223e-04 4.19855628e-04 1.28658725e-03 ... 5.26415899e-03\n",
      "  7.84612499e-01 5.60067281e-04]\n",
      " [3.36227188e-01 2.97554134e-02 2.97621171e-03 ... 3.13171192e-03\n",
      "  2.13534950e-04 1.68821494e-03]\n",
      " ...\n",
      " [2.22679734e-01 1.89255665e-01 1.89344815e-02 ... 4.97409087e-02\n",
      "  1.01896303e-03 4.94456547e-01]\n",
      " [6.26499242e-03 4.54626385e-01 2.22300006e-03 ... 3.37009851e-01\n",
      "  2.06049403e-04 5.64846704e-04]\n",
      " [2.02580788e-01 2.00556696e-03 3.34151837e-03 ... 1.08539283e-02\n",
      "  8.35018270e-03 3.54468091e-03]]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "backward() got an unexpected keyword argument 'retain_variables'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-c7bc734e5e35>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'__main__'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-7-79da5c3dc63f>\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m     72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     73\u001b[0m             \u001b[0mopt_d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m             \u001b[0md_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretain_variables\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     75\u001b[0m             \u001b[0mopt_d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: backward() got an unexpected keyword argument 'retain_variables'"
     ]
    }
   ],
   "source": [
    "# Run the GAN training demo when this code is executed as the main module.\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 64-bit ('base': conda)",
   "language": "python",
   "name": "python37464bitbaseconda5073072be3c4479bad20c57a46be46a1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
